{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pLOYL1PJAAtK"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3fJWQ8WSAFhh"
      },
      "outputs": [],
      "source": [
        "# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cd1dhL4Ykbm7"
      },
      "source": [
        "# Generating Images with BigGAN\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/hub/tutorials/biggan_generation_with_tf_hub\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/hub/tutorials/biggan_generation_with_tf_hub.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/hub/tutorials/biggan_generation_with_tf_hub.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/hub/tutorials/biggan_generation_with_tf_hub.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://tfhub.dev/s?q=deepmind%2Fbiggan\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub models</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-1NTVIH6ABK-"
      },
      "source": [
        "This notebook is a demo for the *BigGAN* image generators available on [TF Hub](https://tfhub.dev/s?publisher=deepmind&q=biggan).\n",
        "\n",
        "See the [BigGAN paper on arXiv](https://arxiv.org/abs/1809.11096) [1] for more information about these models.\n",
        "\n",
        "After connecting to a runtime, get started by following these instructions:\n",
        "\n",
        "1. (Optional) Update the selected **`module_path`** in the first code cell below to load a BigGAN generator for a different image resolution.\n",
        "2. Click **Runtime > Run all** to run each cell in order.\n",
        "  * Afterwards, the interactive visualizations should update automatically when you modify the settings using the sliders and dropdown menus.\n",
        "  * If not, press the **Play** button by the cell to re-render outputs manually.\n",
        "\n",
        "Note: if you run into any issues, it can help to click **Runtime > Restart and run all...** to restart your runtime and rerun all cells from scratch.\n",
        "\n",
        "[1] Andrew Brock, Jeff Donahue, and Karen Simonyan. [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://arxiv.org/abs/1809.11096). *arxiv:1809.11096*, 2018."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XS1_N6hKj8cz"
      },
      "source": [
        "First, set the module path.\n",
        "By default, we load the BigGAN-deep generator for 256x256 images from **`https://tfhub.dev/deepmind/biggan-deep-256/1`**.\n",
        "To generate 128x128 or 512x512 images or to use the original BigGAN generators, comment out the active **`module_path`** setting and uncomment one of the others."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OJCIhQPClKJ1"
      },
      "outputs": [],
      "source": [
        "# BigGAN-deep models\n",
        "# module_path = 'https://tfhub.dev/deepmind/biggan-deep-128/1'  # 128x128 BigGAN-deep\n",
        "module_path = 'https://tfhub.dev/deepmind/biggan-deep-256/1'  # 256x256 BigGAN-deep\n",
        "# module_path = 'https://tfhub.dev/deepmind/biggan-deep-512/1'  # 512x512 BigGAN-deep\n",
        "\n",
        "# BigGAN (original) models\n",
        "# module_path = 'https://tfhub.dev/deepmind/biggan-128/2'  # 128x128 BigGAN\n",
        "# module_path = 'https://tfhub.dev/deepmind/biggan-256/2'  # 256x256 BigGAN\n",
        "# module_path = 'https://tfhub.dev/deepmind/biggan-512/2'  # 512x512 BigGAN"
      ]
    },
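    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The six paths above follow one naming pattern, so the choice can also be made programmatically. A minimal sketch, assuming a hypothetical helper named `biggan_module_path` (it is not part of TF Hub or of this notebook) and only the module versions listed in the cell above:\n",
        "\n",
        "```python\n",
        "def biggan_module_path(resolution=256, deep=True):\n",
        "  # Hypothetical helper: builds one of the six URLs listed in the cell above.\n",
        "  if resolution not in (128, 256, 512):\n",
        "    raise ValueError('resolution must be 128, 256, or 512')\n",
        "  # Per that list, BigGAN-deep modules use version 1 and original BigGAN modules use version 2.\n",
        "  name, version = ('biggan-deep', 1) if deep else ('biggan', 2)\n",
        "  return 'https://tfhub.dev/deepmind/{}-{}/{}'.format(name, resolution, version)\n",
        "\n",
        "assert biggan_module_path(256) == 'https://tfhub.dev/deepmind/biggan-deep-256/1'\n",
        "assert biggan_module_path(128, deep=False) == 'https://tfhub.dev/deepmind/biggan-128/2'\n",
        "```\n",
        "\n",
        "Other versions of these modules may exist on TF Hub; the sketch covers only the ones named in this notebook."
      ]
    },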
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JJrTM6hAi0CJ"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lOZnst2jeWDL"
      },
      "outputs": [],
      "source": [
        "import tensorflow.compat.v1 as tf\n",
        "tf.disable_v2_behavior()\n",
        "\n",
        "import os\n",
        "import io\n",
        "import IPython.display\n",
        "import numpy as np\n",
        "import PIL.Image\n",
        "from scipy.stats import truncnorm\n",
        "import tensorflow_hub as hub\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "stWb21nlcyCm"
      },
      "source": [
        "## Load a BigGAN generator module from TF Hub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tVgwgJiCH3PV"
      },
      "outputs": [],
      "source": [
        "tf.reset_default_graph()\n",
        "print('Loading BigGAN module from:', module_path)\n",
        "module = hub.Module(module_path)\n",
        "inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)\n",
        "          for k, v in module.get_input_info_dict().items()}\n",
        "output = module(inputs)\n",
        "\n",
        "print()\n",
        "print('Inputs:\\n', '\\n'.join(\n",
        "    '  {}: {}'.format(*kv) for kv in inputs.items()))\n",
        "print()\n",
        "print('Output:', output)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ry62-8SWfuds"
      },
      "source": [
        "## Define some functions for sampling and displaying BigGAN images"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "46M8prJPDEsV"
      },
      "outputs": [],
      "source": [
        "input_z = inputs['z']\n",
        "input_y = inputs['y']\n",
        "input_trunc = inputs['truncation']\n",
        "\n",
        "dim_z = input_z.shape.as_list()[1]\n",
        "vocab_size = input_y.shape.as_list()[1]\n",
        "\n",
        "def truncated_z_sample(batch_size, truncation=1., seed=None):\n",
        "  state = None if seed is None else np.random.RandomState(seed)\n",
        "  values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)\n",
        "  return truncation * values\n",
        "\n",
        "def one_hot(index, vocab_size=vocab_size):\n",
        "  index = np.asarray(index)\n",
        "  if len(index.shape) == 0:\n",
        "    index = np.asarray([index])\n",
        "  assert len(index.shape) == 1\n",
        "  num = index.shape[0]\n",
        "  output = np.zeros((num, vocab_size), dtype=np.float32)\n",
        "  output[np.arange(num), index] = 1\n",
        "  return output\n",
        "\n",
        "def one_hot_if_needed(label, vocab_size=vocab_size):\n",
        "  label = np.asarray(label)\n",
        "  if len(label.shape) <= 1:\n",
        "    label = one_hot(label, vocab_size)\n",
        "  assert len(label.shape) == 2\n",
        "  return label\n",
        "\n",
        "def sample(sess, noise, label, truncation=1., batch_size=8,\n",
        "           vocab_size=vocab_size):\n",
        "  noise = np.asarray(noise)\n",
        "  label = np.asarray(label)\n",
        "  num = noise.shape[0]\n",
        "  if len(label.shape) == 0:\n",
        "    label = np.asarray([label] * num)\n",
        "  if label.shape[0] != num:\n",
        "    raise ValueError('Got # noise samples ({}) != # label samples ({})'\n",
        "                     .format(noise.shape[0], label.shape[0]))\n",
        "  label = one_hot_if_needed(label, vocab_size)\n",
        "  ims = []\n",
        "  for batch_start in range(0, num, batch_size):\n",
        "    s = slice(batch_start, min(num, batch_start + batch_size))\n",
        "    feed_dict = {input_z: noise[s], input_y: label[s], input_trunc: truncation}\n",
        "    ims.append(sess.run(output, feed_dict=feed_dict))\n",
        "  ims = np.concatenate(ims, axis=0)\n",
        "  assert ims.shape[0] == num\n",
        "  ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)\n",
        "  ims = np.uint8(ims)\n",
        "  return ims\n",
        "\n",
        "def interpolate(A, B, num_interps):\n",
        "  if A.shape != B.shape:\n",
        "    raise ValueError('A and B must have the same shape to interpolate.')\n",
        "  alphas = np.linspace(0, 1, num_interps)\n",
        "  return np.array([(1-a)*A + a*B for a in alphas])\n",
        "\n",
        "def imgrid(imarray, cols=5, pad=1):\n",
        "  if imarray.dtype != np.uint8:\n",
        "    raise ValueError('imgrid input imarray must be uint8')\n",
        "  pad = int(pad)\n",
        "  assert pad >= 0\n",
        "  cols = int(cols)\n",
        "  assert cols >= 1\n",
        "  N, H, W, C = imarray.shape\n",
        "  rows = N // cols + int(N % cols != 0)\n",
        "  batch_pad = rows * cols - N\n",
        "  assert batch_pad >= 0\n",
        "  post_pad = [batch_pad, pad, pad, 0]\n",
        "  pad_arg = [[0, p] for p in post_pad]\n",
        "  imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255)\n",
        "  H += pad\n",
        "  W += pad\n",
        "  grid = (imarray\n",
        "          .reshape(rows, cols, H, W, C)\n",
        "          .transpose(0, 2, 1, 3, 4)\n",
        "          .reshape(rows*H, cols*W, C))\n",
        "  if pad:\n",
        "    grid = grid[:-pad, :-pad]\n",
        "  return grid\n",
        "\n",
        "def imshow(a, format='png', jpeg_fallback=True):\n",
        "  a = np.asarray(a, dtype=np.uint8)\n",
        "  data = io.BytesIO()\n",
        "  PIL.Image.fromarray(a).save(data, format)\n",
        "  im_data = data.getvalue()\n",
        "  try:\n",
        "    disp = IPython.display.display(IPython.display.Image(im_data))\n",
        "  except IOError:\n",
        "    if jpeg_fallback and format != 'jpeg':\n",
        "      print(('Warning: image was too large to display in format \"{}\"; '\n",
        "             'trying jpeg instead.').format(format))\n",
        "      return imshow(a, format='jpeg')\n",
        "    else:\n",
        "      raise\n",
        "  return disp"
      ]
    },
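    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`truncated_z_sample` in the cell above draws each latent coordinate from a standard normal restricted to `[-2, 2]` and then scales it by `truncation`, so every coordinate lies in `[-2 * truncation, 2 * truncation]`. A small standalone check of that bound, assuming only NumPy and SciPy; `dim_z = 128` is an illustrative value here, whereas the notebook reads the real value from the module's input shape:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from scipy.stats import truncnorm\n",
        "\n",
        "dim_z = 128  # Illustrative assumption; not read from the module.\n",
        "\n",
        "def truncated_z_sample(batch_size, truncation=1., seed=None):\n",
        "  state = None if seed is None else np.random.RandomState(seed)\n",
        "  values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)\n",
        "  return truncation * values\n",
        "\n",
        "for truncation in (1.0, 0.4):\n",
        "  z = truncated_z_sample(4, truncation, seed=0)\n",
        "  # Every coordinate is bounded by 2 * truncation in absolute value.\n",
        "  assert np.abs(z).max() <= 2 * truncation\n",
        "  print(truncation, z.shape, float(z.std()))\n",
        "```\n",
        "\n",
        "The BigGAN paper [1] describes this as the truncation trick: smaller values concentrate the latents near zero, which tends to give more typical but less varied samples."
      ]
    },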
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uCeCg3Sdf8Nv"
      },
      "source": [
        "## Create a TensorFlow session and initialize variables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rYJor5bOaVn1"
      },
      "outputs": [],
      "source": [
        "initializer = tf.global_variables_initializer()\n",
        "sess = tf.Session()\n",
        "sess.run(initializer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SeZ7u3rWd9jz"
      },
      "source": [
        "# Explore BigGAN samples of a particular category\n",
        "\n",
        "Try varying the **`truncation`** value.\n",
        "\n",
        "(Double-click on the cell to view code.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HuCO9tv3IKT2"
      },
      "outputs": [],
      "source": [
        "#@title Category-conditional sampling { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "num_samples = 10 #@param {type:\"slider\", min:1, max:20, step:1}\n",
        "truncation = 0.4 #@param {type:\"slider\", min:0.02, max:1, step:0.02}\n",
        "noise_seed = 0 #@param {type:\"slider\", min:0, max:100, step:1}\n",
        "category = \"933) cheeseburger\"\n", 
        "\n",
        "z = truncated_z_sample(num_samples, truncation, noise_seed)\n",
        "y = int(category.split(')')[0])\n",
        "\n",
        "ims = sample(sess, z, y, truncation=truncation)\n",
        "imshow(imgrid(ims, cols=min(num_samples, 5)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hHNXtvuQgKwa"
      },
      "source": [
        "# Interpolate between BigGAN samples\n",
        "\n",
        "Try setting different **`category`**s  with the same **`noise_seed`**s, or the same **`category`**s with different **`noise_seed`**s. Or go wild and set both any way you like!\n",
        "\n",
        "(Double-click on the cell to view code.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dSAyfDfnVugs"
      },
      "outputs": [],
      "source": [
        "#@title Interpolation { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "num_samples = 2 #@param {type:\"slider\", min:1, max:5, step:1}\n",
        "num_interps = 5 #@param {type:\"slider\", min:2, max:10, step:1}\n",
        "truncation = 0.2 #@param {type:\"slider\", min:0.02, max:1, step:0.02}\n",
        "noise_seed_A = 0 #@param {type:\"slider\", min:0, max:100, step:1}\n",
        "category_A = \"207) golden retriever\"\n",
        "noise_seed_B = 0 #@param {type:\"slider\", min:0, max:100, step:1}\n",
        "category_B = \"8) hen\"\n",
        "\n",
        "def interpolate_and_shape(A, B, num_interps):\n",
        "  interps = interpolate(A, B, num_interps)\n",
        "  return (interps.transpose(1, 0, *range(2, len(interps.shape)))\n",
        "                 .reshape(num_samples * num_interps, *interps.shape[2:]))\n",
        "\n",
        "z_A, z_B = [truncated_z_sample(num_samples, truncation, noise_seed)\n",
        "            for noise_seed in [noise_seed_A, noise_seed_B]]\n",
        "y_A, y_B = [one_hot([int(category.split(')')[0])] * num_samples)\n",
        "            for category in [category_A, category_B]]\n",
        "\n",
        "z_interp = interpolate_and_shape(z_A, z_B, num_interps)\n",
        "y_interp = interpolate_and_shape(y_A, y_B, num_interps)\n",
        "\n",
        "ims = sample(sess, z_interp, y_interp, truncation=truncation)\n",
        "imshow(imgrid(ims, cols=num_interps))"
      ]
    }
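    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The transpose and reshape in `interpolate_and_shape` order the batch so that each sample's interpolation steps are contiguous, which is why `imgrid(ims, cols=num_interps)` shows one interpolation per row. A minimal NumPy sketch of that ordering with toy 3-dimensional vectors (the sizes and values are illustrative assumptions, not the module's dimensions):\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "num_samples, num_interps = 2, 3\n",
        "A = np.zeros((num_samples, 3))\n",
        "B = np.stack([np.full(3, 1.0), np.full(3, 10.0)])  # A different endpoint per sample.\n",
        "\n",
        "alphas = np.linspace(0, 1, num_interps)\n",
        "interps = np.array([(1 - a) * A + a * B for a in alphas])\n",
        "assert interps.shape == (num_interps, num_samples, 3)\n",
        "\n",
        "# Swap to sample-major order, then flatten the first two axes.\n",
        "flat = interps.transpose(1, 0, 2).reshape(num_samples * num_interps, 3)\n",
        "assert np.allclose(flat[:, 0], [0.0, 0.5, 1.0, 0.0, 5.0, 10.0])\n",
        "```\n",
        "\n",
        "Rows 0 to 2 of `flat` walk sample 0 from A to B, and rows 3 to 5 do the same for sample 1."
      ]
    }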
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "pLOYL1PJAAtK"
      ],
      "name": "biggan_generation_with_tf_hub.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
